//
// Copyright 2023 The GUAC Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package backend

import (
	"context"
	stdsql "database/sql"
	"fmt"
	"strings"

	"entgo.io/contrib/entgql"
	"entgo.io/ent/dialect/sql"
	"github.com/google/uuid"
	"github.com/guacsec/guac/internal/testing/ptrfrom"
	"github.com/guacsec/guac/pkg/assembler/backends/ent"
	"github.com/guacsec/guac/pkg/assembler/backends/ent/predicate"
	"github.com/guacsec/guac/pkg/assembler/backends/ent/vulnerabilityid"
	"github.com/guacsec/guac/pkg/assembler/graphql/model"
	"github.com/guacsec/guac/pkg/assembler/helpers"
	"github.com/pkg/errors"
	"github.com/vektah/gqlparser/v2/gqlerror"
)

const (
	// vulnTypeString is the node type passed to toGlobalID when building
	// global IDs for vulnerability type nodes (see vulnTypeGlobalID).
	vulnTypeString = "vulnerability_types"
	// NoVuln is the vulnerability type value that marks the absence of a
	// vulnerability. vulnerabilityQueryPredicates uses it to select
	// (NoVuln == true) or exclude (NoVuln == false) such records.
	NoVuln         = "novuln"
)

func vulnTypeGlobalID(id string) string {
	return toGlobalID(vulnTypeString, id)
}

func vulnIDGlobalID(id string) string {
	return toGlobalID(vulnerabilityid.Table, id)
}

// IngestVulnerability upserts a single vulnerability inside a transaction and
// returns the global IDs of its type node and vulnerability ID node.
// Any transaction error is returned unchanged and the result is nil.
func (b *EntBackend) IngestVulnerability(ctx context.Context, vuln model.IDorVulnerabilityInput) (*model.VulnerabilityIDs, error) {
	upsert := func(txCtx context.Context) (*model.VulnerabilityIDs, error) {
		return upsertVulnerability(txCtx, ent.TxFromContext(txCtx), vuln)
	}

	ids, err := WithinTX(ctx, b.client, upsert)
	if err != nil {
		return nil, err
	}
	return ids, nil
}

// IngestVulnerabilities upserts all vulns in a single transaction and returns
// one VulnerabilityIDs entry per input, in input order. Transaction failures
// are reported as a gqlerror prefixed with the function name.
func (b *EntBackend) IngestVulnerabilities(ctx context.Context, vulns []*model.IDorVulnerabilityInput) ([]*model.VulnerabilityIDs, error) {
	const funcName = "IngestVulnerabilities"

	bulkUpsert := func(txCtx context.Context) (*[]model.VulnerabilityIDs, error) {
		return upsertBulkVulnerability(txCtx, ent.TxFromContext(txCtx), vulns)
	}

	ids, err := WithinTX(ctx, b.client, bulkUpsert)
	if err != nil {
		return nil, gqlerror.Errorf("%v :: %s", funcName, err)
	}

	// Hand out pointers to independent copies; a nil slice is kept when no
	// IDs were produced.
	var collected []*model.VulnerabilityIDs
	for i := range *ids {
		item := (*ids)[i]
		collected = append(collected, &item)
	}
	return collected, nil
}

// VulnerabilityList returns one page of vulnerability ID nodes matching spec,
// paginated with entgql cursors.
//
// after, when non-nil, must be a global ID whose node type is
// vulnerabilityid.Table; its UUID is used as the "after" cursor. first, when
// non-nil, is passed to Paginate to bound the page size. A query that finds
// nothing is reported as (nil, nil) rather than an empty connection.
func (b *EntBackend) VulnerabilityList(ctx context.Context, spec model.VulnerabilitySpec, after *string, first *int) (*model.VulnerabilityConnection, error) {
	var afterCursor *entgql.Cursor[uuid.UUID]
	if after != nil {
		globalID := fromGlobalID(*after)
		if globalID.nodeType != vulnerabilityid.Table {
			return nil, fmt.Errorf("after cursor is not type vulnerability but type: %s", globalID.nodeType)
		}
		afterUUID, err := uuid.Parse(globalID.id)
		if err != nil {
			return nil, fmt.Errorf("failed to parse global ID with error: %w", err)
		}
		afterCursor = &ent.Cursor{ID: afterUUID}
	}

	vulnConn, err := b.client.VulnerabilityID.Query().
		Where(vulnerabilityQueryPredicates(spec)...).
		Paginate(ctx, afterCursor, first, nil, nil)
	if err != nil {
		return nil, fmt.Errorf("failed vulnerability query with error: %w", err)
	}

	// Not found: either no connection at all, or a page without a start cursor.
	if vulnConn == nil || vulnConn.PageInfo.StartCursor == nil {
		return nil, nil
	}

	edges := make([]*model.VulnerabilityEdge, 0, len(vulnConn.Edges))
	for _, edge := range vulnConn.Edges {
		edges = append(edges, &model.VulnerabilityEdge{
			Cursor: vulnIDGlobalID(edge.Cursor.ID.String()),
			Node:   toModelVulnerabilityFromVulnerabilityID(edge.Node),
		})
	}

	pageInfo := &model.PageInfo{
		HasNextPage: vulnConn.PageInfo.HasNextPage,
		StartCursor: ptrfrom.String(vulnIDGlobalID(vulnConn.PageInfo.StartCursor.ID.String())),
	}
	// EndCursor is checked on its own rather than assumed non-nil alongside
	// StartCursor, so a malformed page cannot cause a nil dereference.
	if vulnConn.PageInfo.EndCursor != nil {
		pageInfo.EndCursor = ptrfrom.String(vulnIDGlobalID(vulnConn.PageInfo.EndCursor.ID.String()))
	}

	return &model.VulnerabilityConnection{
		TotalCount: vulnConn.TotalCount,
		PageInfo:   pageInfo,
		Edges:      edges,
	}, nil
}

func (b *EntBackend) Vulnerabilities(ctx context.Context, filter *model.VulnerabilitySpec) ([]*model.Vulnerability, error) {
	if filter == nil {
		filter = &model.VulnerabilitySpec{}
	}
	records, err := getVulnerabilities(ctx, b.client, *filter)
	if err != nil {
		return nil, fmt.Errorf("getVulnerabilities with error: %w", err)
	}
	return toModelVulnerability(records), nil
}

// getVulnerabilities fetches every VulnerabilityID row matching filter.
func getVulnerabilities(ctx context.Context, client *ent.Client, filter model.VulnerabilitySpec) (ent.VulnerabilityIDs, error) {
	predicates := vulnerabilityQueryPredicates(filter)
	query := client.VulnerabilityID.Query().Where(predicates...)

	results, err := query.All(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed vulnerability query with error: %w", err)
	}
	return results, nil
}

func vulnerabilityQueryPredicates(filter model.VulnerabilitySpec) []predicate.VulnerabilityID {
	var where = []predicate.VulnerabilityID{
		optionalPredicate(filter.ID, IDEQ),
	}

	// setting noVuln to true return all packages that have no vulnerabilities
	if filter.NoVuln != nil && *filter.NoVuln {
		filter.Type = ptrfrom.String(NoVuln)
		filter.VulnerabilityID = ptrfrom.String("")
	}

	// setting noVuln to false return all packages with vulnerabilities. This adds a
	// check to make sure type is not equal to `novuln`
	if filter.NoVuln != nil && !*filter.NoVuln {
		where = append(where,
			optionalPredicate(ptrfrom.String(strings.ToLower(NoVuln)), vulnerabilityid.TypeNEQ),
		)
	}

	if filter.Type != nil {
		where = append(where,
			optionalPredicate(ptrfrom.String(strings.ToLower(*filter.Type)), vulnerabilityid.TypeEQ),
		)
	}

	if (filter.Type != nil && strings.ToLower(*filter.Type) != NoVuln) || filter.VulnerabilityID != nil {
		if filter.VulnerabilityID != nil {
			where = append(where,
				optionalPredicate(ptrfrom.String(strings.ToLower(*filter.VulnerabilityID)), vulnerabilityid.VulnerabilityIDEQ),
			)
		}
	}

	return where
}

// upsertBulkVulnerability inserts vulnInputs inside tx in batches of
// MaxBatchSize. Rows whose (type, vulnerability_id) pair already exists are
// left untouched (ON CONFLICT DO NOTHING).
//
// It returns one VulnerabilityIDs per input, in input order. Each vulnerability
// ID node's UUID comes from generateUUIDKey over the helpers.VulnServerKey key.
// That helper is presumably deterministic (verify), which is what lets the
// returned ID be valid even when the insert was skipped. On error it returns a
// nil slice pointer.
func upsertBulkVulnerability(ctx context.Context, tx *ent.Tx, vulnInputs []*model.IDorVulnerabilityInput) (*[]model.VulnerabilityIDs, error) {
	batches := chunk(vulnInputs, MaxBatchSize)
	ids := make([]model.VulnerabilityIDs, 0, len(vulnInputs))

	for _, vulns := range batches {
		creates := make([]*ent.VulnerabilityIDCreate, len(vulns))
		for i, vuln := range vulns {
			vulnIDs := helpers.GetKey[*model.VulnerabilityInputSpec, helpers.VulnIds](vuln.VulnerabilityInput, helpers.VulnServerKey)
			vulnID := generateUUIDKey([]byte(vulnIDs.VulnerabilityID))
			creates[i] = generateVulnerabilityIDCreate(tx, &vulnID, vuln)

			ids = append(ids, model.VulnerabilityIDs{
				// The type is stored lowercased (generateVulnerabilityIDCreate), and the
				// query paths build type IDs from the stored value. Lowercase here too
				// so ingest and query return the same global ID.
				VulnerabilityTypeID: vulnTypeGlobalID(strings.ToLower(vuln.VulnerabilityInput.Type)),
				VulnerabilityNodeID: vulnIDGlobalID(vulnID.String()),
			})
		}

		err := tx.VulnerabilityID.CreateBulk(creates...).
			OnConflict(
				sql.ConflictColumns(vulnerabilityid.FieldType, vulnerabilityid.FieldVulnerabilityID),
			).
			DoNothing().
			Exec(ctx)
		if err != nil {
			return nil, errors.Wrap(err, "bulk upsert vulnerability")
		}
	}

	return &ids, nil
}

// generateVulnerabilityIDCreate builds (but does not execute) a create
// statement for a VulnerabilityID row with the given UUID. Both the type and
// the vulnerability ID are stored lowercased.
func generateVulnerabilityIDCreate(tx *ent.Tx, vulnID *uuid.UUID, vulnInput *model.IDorVulnerabilityInput) *ent.VulnerabilityIDCreate {
	create := tx.VulnerabilityID.Create()
	create.SetID(*vulnID)

	input := vulnInput.VulnerabilityInput
	create.SetType(strings.ToLower(input.Type))
	create.SetVulnerabilityID(strings.ToLower(input.VulnerabilityID))

	return create
}

// upsertVulnerability inserts a single VulnerabilityID row inside tx. If a row
// with the same (type, vulnerability_id) already exists, the insert is skipped
// (ON CONFLICT DO NOTHING).
//
// It returns the global IDs of the vulnerability's type node and vulnerability
// ID node. The node UUID comes from generateUUIDKey over the
// helpers.VulnServerKey key, so it is computed without reading back the row.
func upsertVulnerability(ctx context.Context, tx *ent.Tx, spec model.IDorVulnerabilityInput) (*model.VulnerabilityIDs, error) {
	vulnIDs := helpers.GetKey[*model.VulnerabilityInputSpec, helpers.VulnIds](spec.VulnerabilityInput, helpers.VulnServerKey)
	vulnID := generateUUIDKey([]byte(vulnIDs.VulnerabilityID))

	err := generateVulnerabilityIDCreate(tx, &vulnID, &spec).
		OnConflict(sql.ConflictColumns(vulnerabilityid.FieldType, vulnerabilityid.FieldVulnerabilityID)).
		DoNothing().
		Exec(ctx)
	// With DoNothing, an already-existing row can surface as sql.ErrNoRows.
	// That is not a failure. errors.Is also matches a wrapped ErrNoRows, which
	// plain equality would miss.
	if err != nil && !errors.Is(err, stdsql.ErrNoRows) {
		return nil, errors.Wrap(err, "upsert vulnerability")
	}

	return &model.VulnerabilityIDs{
		// The type is stored lowercased (generateVulnerabilityIDCreate), and the
		// query paths build type IDs from the stored value. Lowercase here too
		// so ingest and query return the same global ID.
		VulnerabilityTypeID: vulnTypeGlobalID(strings.ToLower(spec.VulnerabilityInput.Type)),
		VulnerabilityNodeID: vulnIDGlobalID(vulnID.String()),
	}, nil
}

// toModelVulnerability groups VulnerabilityID rows by their Type and returns
// one model.Vulnerability per distinct type, carrying that type's
// vulnerability IDs.
//
// Groups appear in the order their type is first seen in collectedVulnID, and
// IDs within a group keep their input order. Nil or empty input yields a nil
// slice.
func toModelVulnerability(collectedVulnID []*ent.VulnerabilityID) []*model.Vulnerability {
	// Group by the raw type string. The previous approach encoded
	// "type,globalID" as a map key and split it back on ",", which corrupted
	// any type containing a comma. Using an index map plus an ordered slice
	// also makes the output order deterministic, unlike ranging over a map.
	var vulnerabilities []*model.Vulnerability
	byType := make(map[string]*model.Vulnerability)

	for _, record := range collectedVulnID {
		vuln, ok := byType[record.Type]
		if !ok {
			vuln = &model.Vulnerability{
				ID:   vulnTypeGlobalID(record.Type),
				Type: record.Type,
			}
			byType[record.Type] = vuln
			vulnerabilities = append(vulnerabilities, vuln)
		}
		vuln.VulnerabilityIDs = append(vuln.VulnerabilityIDs, toModelVulnerabilityID(record))
	}

	return vulnerabilities
}

// toModelVulnerabilityID converts a VulnerabilityID row into its GraphQL model,
// replacing the raw UUID with the node's global ID.
func toModelVulnerabilityID(vulnID *ent.VulnerabilityID) *model.VulnerabilityID {
	out := new(model.VulnerabilityID)
	out.ID = vulnIDGlobalID(vulnID.ID.String())
	out.VulnerabilityID = vulnID.VulnerabilityID
	return out
}

// getVulnType resolves a vulnerability type node. nodeID is treated as the
// type name and matched case-insensitively via vulnerabilityQueryPredicates.
// The returned Vulnerability has an empty VulnerabilityIDs list. An error is
// returned if the query fails or no row of that type exists.
func (b *EntBackend) getVulnType(ctx context.Context, nodeID string) (*model.Vulnerability, error) {
	// Only the stored type string is needed, so fetch at most one matching row
	// instead of every vulnerability ID of this type.
	vulnIDs, err := b.client.VulnerabilityID.Query().
		Where(vulnerabilityQueryPredicates(model.VulnerabilitySpec{Type: &nodeID})...).
		Limit(1).
		All(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get vuln type for node ID: %s with error: %w", nodeID, err)
	}
	if len(vulnIDs) == 0 {
		return nil, fmt.Errorf("failed to get vuln type for node ID: %s", nodeID)
	}

	return &model.Vulnerability{
		ID:               vulnTypeGlobalID(vulnIDs[0].Type),
		Type:             vulnIDs[0].Type,
		VulnerabilityIDs: []*model.VulnerabilityID{},
	}, nil
}

// vulnTypeNeighbors returns the neighbors of a vulnerability type node (nodeID
// is the type name). When the type -> vulnerability ID edge is allowed, it
// returns one Vulnerability per matching row, each carrying exactly that
// row's vulnerability ID. Otherwise it returns nil.
func (b *EntBackend) vulnTypeNeighbors(ctx context.Context, nodeID string, allowedEdges edgeMap) ([]model.Node, error) {
	if !allowedEdges[model.EdgeVulnerabilityTypeVulnerabilityID] {
		return nil, nil
	}

	spec := model.VulnerabilitySpec{Type: &nodeID}
	records, err := b.client.VulnerabilityID.Query().
		Where(vulnerabilityQueryPredicates(spec)...).
		All(ctx)
	if err != nil {
		return []model.Node{}, fmt.Errorf("failed to get vuln type for node ID: %s with error: %w", nodeID, err)
	}

	var neighbors []model.Node
	for _, rec := range records {
		neighbors = append(neighbors, &model.Vulnerability{
			ID:               vulnTypeGlobalID(rec.Type),
			Type:             rec.Type,
			VulnerabilityIDs: []*model.VulnerabilityID{toModelVulnerabilityID(rec)},
		})
	}
	return neighbors, nil
}

// vulnIdNeighbors returns the neighbors of a vulnerability ID node identified
// by nodeID, limited to the edge kinds enabled in allowedEdges.
//
// Related CertifyVuln, VulnEqual (both sides), CertifyVex and
// VulnerabilityMetadata records are eager-loaded only when their edge is
// allowed. The owning vulnerability type is emitted (with an empty ID list)
// when the vulnerability ID -> type edge is allowed.
func (b *EntBackend) vulnIdNeighbors(ctx context.Context, nodeID string, allowedEdges edgeMap) ([]model.Node, error) {
	query := b.client.VulnerabilityID.Query().
		Where(vulnerabilityQueryPredicates(model.VulnerabilitySpec{ID: &nodeID})...)

	if allowedEdges[model.EdgeVulnerabilityCertifyVuln] {
		query = query.WithCertifyVuln(func(q *ent.CertifyVulnQuery) {
			getCertVulnObject(q)
		})
	}
	if allowedEdges[model.EdgeVulnerabilityVulnEqual] {
		query = query.
			WithVulnEqualVulnA(func(q *ent.VulnEqualQuery) {
				getVulnEqualObject(q)
			}).
			WithVulnEqualVulnB(func(q *ent.VulnEqualQuery) {
				getVulnEqualObject(q)
			})
	}
	if allowedEdges[model.EdgeVulnerabilityCertifyVexStatement] {
		query = query.WithVex(func(q *ent.CertifyVexQuery) {
			getVEXObject(q)
		})
	}
	if allowedEdges[model.EdgeVulnMetadataVulnerability] {
		query = query.WithMetadata(func(q *ent.VulnerabilityMetadataQuery) {
			getVulnMetadataObject(q)
		})
	}

	records, err := query.All(ctx)
	if err != nil {
		return []model.Node{}, fmt.Errorf("failed to get vulnerabilityID for node ID: %s with error: %w", nodeID, err)
	}

	includeType := allowedEdges[model.EdgeVulnerabilityIDVulnerabilityType]

	var neighbors []model.Node
	for _, rec := range records {
		if includeType {
			neighbors = append(neighbors, &model.Vulnerability{
				ID:               vulnTypeGlobalID(rec.Type),
				Type:             rec.Type,
				VulnerabilityIDs: []*model.VulnerabilityID{},
			})
		}
		// Edges that were not eager-loaded are empty and contribute nothing.
		for _, cv := range rec.Edges.CertifyVuln {
			neighbors = append(neighbors, toModelCertifyVulnerability(cv))
		}
		for _, eqA := range rec.Edges.VulnEqualVulnA {
			neighbors = append(neighbors, toModelVulnEqual(eqA))
		}
		for _, eqB := range rec.Edges.VulnEqualVulnB {
			neighbors = append(neighbors, toModelVulnEqual(eqB))
		}
		for _, vex := range rec.Edges.Vex {
			neighbors = append(neighbors, toModelCertifyVEXStatement(vex))
		}
		for _, md := range rec.Edges.Metadata {
			neighbors = append(neighbors, toModelVulnerabilityMetadata(md))
		}
	}

	return neighbors, nil
}
